Example: Holdout set¶
This example shows when and how to use ATOM's holdout set in an exploration pipeline.
The data used is a variation on the Australian weather dataset from Kaggle. You can download it from here. The goal of this dataset is to predict whether or not it will rain tomorrow by training a binary classifier on the target column RainTomorrow.
Load the data¶
# Import packages
import pandas as pd
from atom import ATOMClassifier
# Load data
X = pd.read_csv("docs_source/examples/datasets/weatherAUS.csv")
# Let's have a look
X.head()
| Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MelbourneAirport | 18.0 | 26.9 | 21.4 | 7.0 | 8.9 | SSE | 41.0 | W | SSE | ... | 95.0 | 54.0 | 1019.5 | 1017.0 | 8.0 | 5.0 | 18.5 | 26.0 | Yes | 0 |
| 1 | Adelaide | 17.2 | 23.4 | 0.0 | NaN | NaN | S | 41.0 | S | WSW | ... | 59.0 | 36.0 | 1015.7 | 1015.7 | NaN | NaN | 17.7 | 21.9 | No | 0 |
| 2 | Cairns | 18.6 | 24.6 | 7.4 | 3.0 | 6.1 | SSE | 54.0 | SSE | SE | ... | 78.0 | 57.0 | 1018.7 | 1016.6 | 3.0 | 3.0 | 20.8 | 24.1 | Yes | 0 |
| 3 | Portland | 13.6 | 16.8 | 4.2 | 1.2 | 0.0 | ESE | 39.0 | ESE | ESE | ... | 76.0 | 74.0 | 1021.4 | 1020.5 | 7.0 | 8.0 | 15.6 | 16.0 | Yes | 1 |
| 4 | Walpole | 16.4 | 19.9 | 0.0 | NaN | NaN | SE | 44.0 | SE | SE | ... | 78.0 | 70.0 | 1019.4 | 1018.9 | NaN | NaN | 17.4 | 18.1 | No | 0 |
5 rows × 22 columns
Run the pipeline¶
# Initialize atom specifying a fraction of the dataset for holdout
atom = ATOMClassifier(X, n_rows=0.5, holdout_size=0.2, verbose=2)
<< ================== ATOM ================== >>

Configuration ==================== >>
Algorithm task: Binary classification.

Dataset stats ==================== >>
Shape: (56877, 22)
Train set size: 42658
Test set size: 14219
Holdout set size: 14219
-------------------------------------
Memory: 10.01 MB
Scaled: False
Missing values: 126822 (10.1%)
Categorical features: 5 (23.8%)
Duplicates: 15 (0.0%)
# The test and holdout fractions are split off after subsampling the dataset
# Also note that the holdout set is not part of atom's dataset
print("Length loaded data:", len(X))
print("Length dataset + holdout:", len(atom.dataset) + len(atom.holdout))
Length loaded data: 142193
Length dataset + holdout: 71096
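The set sizes above can be reproduced with a quick manual split. The sketch below uses dummy data and scikit-learn, not ATOM's internal code (the exact rounding rules are an ATOM implementation detail), but it shows how `n_rows=0.5` and `holdout_size=0.2` interact: the holdout fraction is carved out of the subsample first, and the test fraction is taken from what remains.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Dummy stand-in for the weather data (100 rows)
df = pd.DataFrame({"x": range(100), "y": [0, 1] * 50})

# n_rows=0.5 -> keep half the rows
sub = df.sample(frac=0.5, random_state=1)

# holdout_size=0.2 -> 20% of the subsample is set apart first
rest, holdout = train_test_split(sub, test_size=0.2, random_state=1)

# taking 25% of the remaining rows leaves an equally sized test set
train, test = train_test_split(rest, test_size=0.25, random_state=1)

print(len(train), len(test), len(holdout))  # 30 10 10
```

With the real data this gives 71096 rows in total, of which 14219 go to the holdout set, 14219 to the test set, and 42658 to the train set, matching the stats printed by atom.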
atom.impute()
atom.encode()
Fitting Imputer...
Imputing missing values...
 --> Dropping 258 samples due to missing values in column MinTemp.
 --> Dropping 127 samples due to missing values in column MaxTemp.
 --> Dropping 553 samples due to missing values in column Rainfall.
 --> Dropping 24308 samples due to missing values in column Evaporation.
 --> Dropping 27187 samples due to missing values in column Sunshine.
 --> Dropping 3739 samples due to missing values in column WindGustDir.
 --> Dropping 3712 samples due to missing values in column WindGustSpeed.
 --> Dropping 3995 samples due to missing values in column WindDir9am.
 --> Dropping 1508 samples due to missing values in column WindDir3pm.
 --> Dropping 539 samples due to missing values in column WindSpeed9am.
 --> Dropping 1077 samples due to missing values in column WindSpeed3pm.
 --> Dropping 706 samples due to missing values in column Humidity9am.
 --> Dropping 1447 samples due to missing values in column Humidity3pm.
 --> Dropping 5610 samples due to missing values in column Pressure9am.
 --> Dropping 5591 samples due to missing values in column Pressure3pm.
 --> Dropping 21520 samples due to missing values in column Cloud9am.
 --> Dropping 22921 samples due to missing values in column Cloud3pm.
 --> Dropping 365 samples due to missing values in column Temp9am.
 --> Dropping 1106 samples due to missing values in column Temp3pm.
 --> Dropping 553 samples due to missing values in column RainToday.
Fitting Encoder...
Encoding categorical columns...
 --> Target-encoding feature Location. Contains 26 classes.
 --> Target-encoding feature WindGustDir. Contains 16 classes.
 --> Target-encoding feature WindDir9am. Contains 16 classes.
 --> Target-encoding feature WindDir3pm. Contains 16 classes.
 --> Ordinal-encoding feature RainToday. Contains 2 classes.
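Conceptually, these two steps amount to dropping rows with missing values and replacing each category with a target-derived number. The following is a hand-rolled sketch on toy data, not ATOM's implementation (ATOM's target encoder applies smoothing, so the actual encoded values differ):

```python
import pandas as pd

# Toy frame mimicking two of the columns above
df = pd.DataFrame({
    "WindDir": ["N", "S", "N", None, "S", "N"],
    "RainToday": ["Yes", "No", "No", "Yes", "No", "Yes"],
    "RainTomorrow": [1, 0, 0, 1, 0, 1],
})

# Impute by dropping samples with missing values
df = df.dropna()

# Target-encode a multi-class feature: replace each category
# with the mean of the target within that category
df["WindDir"] = df["WindDir"].map(df.groupby("WindDir")["RainTomorrow"].mean())

# Ordinal-encode the binary feature
df["RainToday"] = df["RainToday"].map({"No": 0, "Yes": 1})
```

Note that ATOM fits these transformers on the train set only and applies them to the test set afterwards, which avoids target leakage from the test rows into the encodings.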
# Unlike train and test, the holdout data set is not transformed until used for predictions
atom.holdout
| Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | WindDir3pm | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 22540 | NorahHead | 15.8 | 23.7 | 0.4 | NaN | NaN | SSW | 50.0 | NW | NaN | ... | 79.0 | 80.0 | 1012.4 | 1009.6 | NaN | NaN | 18.4 | 18.9 | No | 0 |
| 22541 | Brisbane | 13.0 | 24.1 | 0.0 | 3.2 | 3.6 | W | 24.0 | SW | WSW | ... | 53.0 | 27.0 | 1019.9 | 1015.9 | 7.0 | 8.0 | 17.3 | 22.1 | No | 0 |
| 22542 | MountGambier | 14.7 | 36.2 | 0.0 | 7.2 | 12.5 | S | 33.0 | N | SSW | ... | 52.0 | 27.0 | 1018.8 | 1017.4 | 7.0 | 2.0 | 25.2 | 35.4 | No | 0 |
| 22543 | Launceston | 12.3 | 21.4 | 0.0 | NaN | NaN | NNW | 52.0 | NNW | NNW | ... | 62.0 | 60.0 | NaN | NaN | 5.0 | 8.0 | 16.2 | 20.4 | No | 0 |
| 22544 | MountGinini | 3.2 | 10.0 | 0.0 | NaN | NaN | WSW | 52.0 | WSW | WSW | ... | 97.0 | 95.0 | NaN | NaN | NaN | NaN | 6.5 | 8.4 | No | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 36754 | MountGinini | 1.6 | 4.4 | 0.0 | NaN | NaN | E | 52.0 | E | E | ... | 100.0 | 100.0 | NaN | NaN | NaN | NaN | 2.7 | 2.6 | No | 1 |
| 36755 | WaggaWagga | 9.9 | 21.8 | 0.0 | 4.6 | 5.7 | WSW | 35.0 | S | SW | ... | 57.0 | 36.0 | 1015.5 | 1013.7 | 7.0 | 7.0 | 17.0 | 21.3 | No | 0 |
| 36756 | Walpole | 8.8 | 16.3 | 0.8 | NaN | NaN | NNW | 37.0 | NNE | N | ... | 84.0 | 79.0 | 1018.4 | 1013.5 | NaN | NaN | 11.0 | 14.6 | No | 1 |
| 36757 | Dartmoor | 8.7 | 15.5 | 2.0 | 1.4 | 5.4 | S | 30.0 | WSW | SSW | ... | 100.0 | 94.0 | 1018.6 | 1020.0 | NaN | NaN | 12.9 | 12.8 | Yes | 0 |
| 36758 | SydneyAirport | 16.8 | 22.6 | 8.4 | 5.0 | 3.8 | S | 57.0 | WNW | S | ... | 79.0 | 75.0 | 1013.2 | 1013.7 | 8.0 | 6.0 | 17.1 | 18.8 | Yes | 0 |
14219 rows × 22 columns
atom.run(models=["GNB", "LR", "RF"])
Training ========================= >>
Models: GNB, LR, RF
Metric: f1

Results for GaussianNB:
Fit ---------------------------------------------
Train evaluation --> f1: 0.604
Test evaluation --> f1: 0.6063
Time elapsed: 0.209s
-------------------------------------------------
Time: 0.209s

Results for LogisticRegression:
Fit ---------------------------------------------
Train evaluation --> f1: 0.6188
Test evaluation --> f1: 0.6162
Time elapsed: 0.323s
-------------------------------------------------
Time: 0.323s

Results for RandomForest:
Fit ---------------------------------------------
Train evaluation --> f1: 1.0
Test evaluation --> f1: 0.6084
Time elapsed: 4.533s
-------------------------------------------------
Time: 4.533s

Final results ==================== >>
Total time: 5.734s
-------------------------------------
GaussianNB         --> f1: 0.6063
LogisticRegression --> f1: 0.6162 !
RandomForest       --> f1: 0.6084 ~
atom.plot_prc()
# Based on the results on the test set, we select the best model for further tuning
atom.run("lr_tuned", n_trials=10)
Training ========================= >>
Models: LR_tuned
Metric: f1

Running hyperparameter tuning for LogisticRegression...

| trial | penalty | C | solver | max_iter | l1_ratio | f1 | best_f1 | time_trial | time_ht | state |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | None | 0.1893 | sag | 540 | 0.4 | 0.6096 | 0.6096 | 0.797s | 0.797s | COMPLETE |
| 1 | l2 | 0.6275 | newto.. | 150 | 0.7 | 0.6101 | 0.6101 | 0.637s | 1.433s | COMPLETE |
| 2 | l1 | 0.7457 | libli.. | 740 | 0.7 | 0.6114 | 0.6114 | 0.815s | 2.248s | COMPLETE |
| 3 | l2 | 0.0759 | newto.. | 290 | 0.4 | 0.6204 | 0.6204 | 0.634s | 2.882s | COMPLETE |
| 4 | l2 | 0.2122 | newto.. | 730 | 0.9 | 0.6273 | 0.6273 | 0.635s | 3.516s | COMPLETE |
| 5 | l2 | 0.0017 | lbfgs | 260 | 1.0 | 0.589 | 0.6273 | 0.581s | 4.097s | COMPLETE |
| 6 | l2 | 0.0137 | sag | 130 | 0.4 | 0.6092 | 0.6273 | 0.615s | 4.711s | COMPLETE |
| 7 | None | 0.0014 | sag | 640 | 0.1 | 0.5909 | 0.6273 | 0.725s | 5.436s | COMPLETE |
| 8 | l2 | 0.0224 | sag | 500 | 1.0 | 0.6226 | 0.6273 | 0.653s | 6.089s | COMPLETE |
| 9 | l1 | 0.1594 | saga | 630 | 0.2 | 0.6236 | 0.6273 | 0.810s | 6.898s | COMPLETE |

Hyperparameter tuning ---------------------------
Best trial --> 4
Best parameters:
 --> penalty: l2
 --> C: 0.2122
 --> solver: newton-cg
 --> max_iter: 730
 --> l1_ratio: 0.9
Best evaluation --> f1: 0.6273
Time elapsed: 6.898s
Fit ---------------------------------------------
Train evaluation --> f1: 0.6188
Test evaluation --> f1: 0.6172
Time elapsed: 0.352s
-------------------------------------------------
Time: 7.251s

Final results ==================== >>
Total time: 7.461s
-------------------------------------
LogisticRegression --> f1: 0.6172
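Each trial samples one hyperparameter configuration, fits the model, and scores it on f1. As a rough standalone stand-in for this loop (a sketch on synthetic data, not ATOM's implementation, which drives the trials with a dedicated optimization library), a randomized search over a similar space looks like:

```python
from scipy.stats import loguniform
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

# Synthetic binary classification data as a stand-in
X_demo, y_demo = make_classification(n_samples=500, random_state=1)

# Sample 10 configurations of C, comparable to n_trials=10
search = RandomizedSearchCV(
    LogisticRegression(max_iter=1000),
    param_distributions={"C": loguniform(1e-3, 1e2)},
    n_iter=10,
    scoring="f1",
    random_state=1,
)
search.fit(X_demo, y_demo)
print(search.best_params_, round(search.best_score_, 4))
```

The key difference is that randomized search draws all configurations independently, while trial-based tuners use the scores of earlier trials to steer later ones toward promising regions of the space.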
Analyze the results¶
We already used the test set to choose the best model for further tuning, so this set is no longer truly independent. Although it may not be directly visible in the results, evaluating the tuned LR model on the test set would now be a mistake, since its estimate carries a selection bias. For this reason, we set apart an extra, independent set to validate the final model: the holdout set. And since we are no longer using the test set for validation, we might as well add it to the training data and make the most of the available samples. Use the full_train method for this.
# Re-train the model on the full dataset (train + test)
atom.lr_tuned.full_train()
Fit ---------------------------------------------
Train evaluation --> f1: 0.6185
Test evaluation --> f1: 0.6185
Time elapsed: 0.717s
# Evaluate on the holdout set
atom.lr_tuned.evaluate(rows="holdout")
accuracy     0.8577
ap           0.7473
ba           0.7480
f1           0.6352
jaccard      0.4654
mcc          0.5606
precision    0.7559
recall       0.5477
auc          0.8873
Name: LR_tuned, dtype: float64
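The refit-and-validate flow above amounts to the following in plain scikit-learn (a sketch on synthetic data, not ATOM's code): refit the chosen model on train + test combined, then score it exactly once on the untouched holdout set.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X_demo, y_demo = make_classification(n_samples=1000, random_state=1)

# Carve out the holdout first, then split the rest into train/test
X_rest, X_hold, y_rest, y_hold = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=1
)
X_train, X_test, y_train, y_test = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=1
)

# full_train: refit the selected model on train + test combined
model = LogisticRegression(max_iter=1000)
model.fit(np.vstack([X_train, X_test]), np.concatenate([y_train, y_test]))

# The final, unbiased performance estimate comes from the holdout set
score = f1_score(y_hold, model.predict(X_hold))
print(round(score, 4))
```

Because the holdout set never influenced imputation, encoding, model selection, or tuning, this single score is an honest estimate of generalization performance; scoring it repeatedly while iterating would reintroduce the same bias we set it aside to avoid.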
atom.lr_tuned.plot_prc(rows="holdout", legend="upper right")